% This file reproduces Figure 9 in the paper

close all
clear

% Model parameters and state grids are declared global. Presumably this is
% so the helper functions called below (PDF, CDF, x_dot_max, s_dot_max,
% C_x, C_s -- defined in other files) can read them; TODO confirm.
global grid_size rho delta alpha
global gamma_x gamma_s
global c_x_1 c_x_2 c_s_1 c_s_2
global x s x_dot_bar s_dot_bar

% ---- numerical and model parameters -------------------------------------
grid_size = 501;          % number of grid points per axis on [0, 1]
rho = 1000;
delta = 0.1;
alpha = 2;                % convexity parameter
x_dot_bar = 10 - delta;   % upper bound on x_dot
s_dot_bar = 10 - delta;   % upper bound on s_dot
error_tol = 1e-11;        % NOTE(review): same value as 10e-12; confirm 1e-11
                          % (not 1e-12) is the intended tolerance

% ---- cost functions: c(z) = c_1*z + c_2*z^alpha --------------------------
c_x_1 = 0;
c_s_1 = 0;

% Assumption 3 (displayed, no semicolon)
gamma_x = min(0.4*(PDF(0)-PDF(1)), 1)
gamma_s = min(0.4*(PDF(0)-PDF(1)), 1)

% Assumption 3: c_2 from a convex combination (weight convex_weight_*) of
% PDF(1) and min(PDF(0)-gamma, PDF(gamma)), scaled by 1/(alpha*delta^(alpha-1)).
convex_weight_x = 0.1;
convex_weight_s = 0.1;
c_x_2 = (convex_weight_x*PDF(1) ...
    + (1-convex_weight_x)*min(PDF(0)-gamma_x, PDF(gamma_x))) ...
    /(alpha*delta^(alpha-1));
c_s_2 = (convex_weight_s*PDF(1) ...
    + (1-convex_weight_s)*min(PDF(0)-gamma_s, PDF(gamma_s))) ...
    /(alpha*delta^(alpha-1));

% The Assumption-3 values just computed are overridden here (displayed).
% NOTE(review): confirm c_2 = 9 is the value used for Figure 9.
c_x_2 = 9
c_s_2 = 9

% State grid on the unit square built from linspace(0,1,grid_size):
% x varies down the rows (x(i,j) is the i-th point) and s varies across
% the columns (s(i,j) is the j-th point).
[s, x] = meshgrid(linspace(0, 1, grid_size));

% The iteration starts from zero value functions and zero policies.
V_x = zeros(grid_size);
V_s = zeros(grid_size);
x_dot = zeros(grid_size);
s_dot = zeros(grid_size);

% Fixed-point iteration on the value functions V_x, V_s and the policies
% x_dot, s_dot. Each sweep takes finite-difference gradients of the current
% value functions, computes policies with x_dot_max / s_dot_max (defined
% elsewhere), and applies
%   V_x <- CDF(x - s) + (1/rho)*(-C_x(x_dot) + dV_x/dx.*x_dot + dV_x/ds.*s_dot)
%   V_s <- CDF(s - x) + (1/rho)*(-C_s(s_dot) + dV_s/dx.*x_dot + dV_s/ds.*s_dot)
% It stops once the summed change of V_x, V_s, x_dot and s_dot is at most
% error_tol, or after max_iter sweeps (with a warning).
%
% The residual is named `residual`, not `error`, so MATLAB's built-in
% error() function is not shadowed in the workspace.
max_iter = 1e5;              % safety cap so a non-converging run terminates
fd_step = 1/(grid_size-1)/2; % NOTE(review): half the grid step 1/(grid_size-1),
                             % which doubles the derivatives -- confirm intended
residual = error_tol*10;     % start above the tolerance to enter the loop
iter = 0;

while residual > error_tol && iter < max_iter
    iter = iter + 1;

    % gradient returns [d/d(column), d/d(row)]; columns index s, rows index x.
    [dV_x_over_ds,dV_x_over_dx] = gradient(V_x,fd_step,fd_step);
    [dV_s_over_ds,dV_s_over_dx] = gradient(V_s,fd_step,fd_step);

    x_dot_new = x_dot_max(dV_x_over_dx);
    s_dot_new = s_dot_max(dV_s_over_ds);

    V_x_new = CDF(x-s)+(1/rho)*(-C_x(x_dot_new)+dV_x_over_dx.*x_dot_new+dV_x_over_ds.*s_dot_new);
    V_s_new = CDF(s-x)+(1/rho)*(-C_s(s_dot_new)+dV_s_over_dx.*x_dot_new+dV_s_over_ds.*s_dot_new);

    % NOTE(review): norm(M,Inf) on a matrix is the maximum absolute ROW SUM,
    % not the largest element; confirm this (stricter) criterion is intended.
    residual = norm(V_x_new-V_x,Inf)+norm(V_s_new-V_s,Inf) ...
        +norm(x_dot_new-x_dot,Inf)+norm(s_dot_new-s_dot,Inf);

    V_x = V_x_new;
    V_s = V_s_new;
    x_dot = x_dot_new;
    s_dot = s_dot_new;
end

if residual > error_tol
    warning('Value iteration reached max_iter = %d without converging (residual = %g).', ...
        max_iter, residual);
end
% Clamp the converged policies on the edges of the unit square so the flow
% does not point out of it. Rows index x and columns index s:
%   first row  (x = 0): x_dot >= 0     last row    (x = 1): x_dot <= 0
%   first col  (s = 0): s_dot >= 0     last column (s = 1): s_dot <= 0
% Each upper edge is clamped from its own row/column. Clamping it from
% row/column 1 would zero the whole edge, because that row/column has
% already been made non-negative.
x_dot(1,:) = max(x_dot(1,:), 0);
s_dot(:,1) = max(s_dot(:,1), 0);
x_dot(grid_size,:) = min(x_dot(grid_size,:), 0);
s_dot(:,grid_size) = min(s_dot(:,grid_size), 0);

% New figure for the phase diagram; all later plots are overlaid on it.
figure
hold on
axis equal
axis tight

% Red region boundaries:
%   - the diagonal on [0, 0.16];
%   - a quadratic arc f(t) on [0.16, 0.405];
%   - a slope-1 line continuing from the arc's end value.
% The arc and the line are drawn both as (t, f(t)) and mirrored as (f(t), t).
% Samples whose mapped coordinate f leaves [0, 1] are dropped.

% diagonal piece
t_diag = linspace(0, .16, 100);
plot(t_diag, t_diag, 'r', 'LineWidth', 1.1)

% quadratic arc: the same samples are plotted in both orientations
t_arc = linspace(.16, .405, 500);
f_arc = .16+.9*(t_arc-.16)+3.7*(t_arc-.16).^2;
keep = (f_arc <= 1) & (f_arc >= 0);
plot(t_arc(keep), f_arc(keep), 'r', 'LineWidth', 1.1)
plot(f_arc(keep), t_arc(keep), 'r', 'LineWidth', 1.1)

% slope-1 continuation beyond 0.405, first orientation (t up to 2)
t_lin = linspace(0.405, 2, 100);
f_lin = t_lin-0.405+.16+.9*(.405-.16)+3.7*(.405-.16).^2;
keep = (f_lin <= 1) & (f_lin >= 0);
plot(t_lin(keep), f_lin(keep), 'r', 'LineWidth', 1.1)

% slope-1 continuation, mirrored orientation (t up to 1)
t_lin = linspace(0.405, 1, 100);
f_lin = t_lin-0.405+.16+.9*(.405-.16)+3.7*(.405-.16).^2;
keep = (f_lin <= 1) & (f_lin >= 0);
plot(f_lin(keep), t_lin(keep), 'r', 'LineWidth', 1.1)

% Draw the policy field (s_dot, x_dot) as streamlines in three regions:
%   region 1 (magenta): x above the red boundary, offset by a small margin;
%   region 2 (blue):    the mirror image with x and s swapped;
%   region 3 (black):   points in neither region 1 nor 2 that lie below both
%                       0.001-shifted quadratic arcs.
% Region membership is a logical mask computed from the boundary formulas.
% The policy is copied onto a zero field only on that mask, so each
% streamslice call only draws inside its own region.

% ---- region 1 ----
in_region1 = ((x > s+0.02) & (s < .16)) ...
    | (x > .18+.9*(s-.16)+3.7*(s-.16).^2) ...
    | ((x > s-0.405+.16+.9*(.405-.16)+3.7*(.405-.16).^2) & (s > 0.405));
x_dot_sub1 = zeros(grid_size,grid_size);
s_dot_sub1 = zeros(grid_size,grid_size);
x_dot_sub1(in_region1) = x_dot(in_region1);
s_dot_sub1(in_region1) = s_dot(in_region1);
h = streamslice(s(1,:),x(:,1),s_dot_sub1,x_dot_sub1);
set(h, 'Color', 'm')

% ---- region 2 (x and s swapped) ----
in_region2 = ((s > x+0.02) & (x < .16)) ...
    | (s > .18+.9*(x-.16)+3.7*(x-.16).^2) ...
    | ((s > x-0.405+.16+.9*(.405-.16)+3.7*(.405-.16).^2) & (x > 0.405));
x_dot_sub2 = zeros(grid_size,grid_size);
s_dot_sub2 = zeros(grid_size,grid_size);
x_dot_sub2(in_region2) = x_dot(in_region2);
s_dot_sub2(in_region2) = s_dot(in_region2);
h = streamslice(s(1,:),x(:,1),s_dot_sub2,x_dot_sub2);
set(h, 'Color', 'b')

% ---- region 3 (in between) ----
% Membership uses the masks rather than testing x_dot_sub1 == 0. A zero
% policy value inside region 1 or 2 does not mean the point is outside it.
in_region3 = ~in_region1 & ~in_region2 ...
    & (x < .161+.9*(s-.16)+3.7*(s-.16).^2) ...
    & (s < .161+.9*(x-.16)+3.7*(x-.16).^2);
x_dot_sub = zeros(grid_size,grid_size);
s_dot_sub = zeros(grid_size,grid_size);
x_dot_sub(in_region3) = x_dot(in_region3);
s_dot_sub(in_region3) = s_dot(in_region3);
h = streamslice(s(1,:),x(:,1),s_dot_sub,x_dot_sub);
set(h, 'Color', 'k')

% Show the unit square with a small margin, label the axes (y label
% upright), and export the figure as figure9.eps.
% NOTE(review): the streamslice calls pass s as the horizontal coordinate
% (streamslice(s(1,:), x(:,1), ...)), yet the horizontal axis is labelled
% 'x'. The parameters are symmetric in x and s, so the picture is roughly
% mirror-symmetric. Confirm the labelling matches the paper's Figure 9.
axis([-0.02 1.02 -0.02 1.02])
xlabel('x')
ylabel('s', 'Rotation', 0)
print('figure9', '-depsc', '-painters')